import torch
from torch import nn
from torch.nn import functional as F

from VAE.decoder import VAE_AttentionBlock, VAE_ResidualBlock

class VAE_Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, padding=1),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(128, 256),
            VAE_ResidualBlock(256, 256),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(256, 512),
            VAE_ResidualBlock(512, 512),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        ) 
    
    def forward(self, x, noise):
        #x: (batch_size, channel, height, width)
        #noise: (batch_size, output_channel, height/8, width/8)
        for module in self:
            if getattr(module, 'stride', None) == (2, 2):
                #right and bottom padding
                x = F.pad(x, (0, 1, 0, 1))
            x = module(x)
        #(batch_size, 8, height, height/8, width/8) -> two tensors (batch_size, 4, height, height/8, width/8)
        mean, log_var = torch.chunk(x, 2, dim=1)
        log_var = torch.clamp(log_var, -30, 20)
        var = log_var.exp()
        stdev = var.sqrt()
        x = mean + stdev * noise
        x *= 0.18215
        return x
